# -*- coding: utf-8 -*-
"""
Created on Tue Apr 21 14:49:36 2020

@author: xiang_yaobing
"""
from astropy.io import fits
import numpy as np
import pandas as pd
from numpy import nan as nan
from pandas import Series, DataFrame
import data_genrate_1 as dg
import os 
from scipy.stats import gaussian_kde
import matplotlib.pyplot as plt
import cv2
def mkdir(path):
    """Create directory *path* (with any missing parents) if absent.

    Prints a status message either way; returns None.
    """
    if os.path.exists(path):
        print ("---  There is this folder(it had already been build)!  ---")
    else:
        # makedirs also creates any missing intermediate directories.
        os.makedirs(path)
        print ("---  new folder...  ---")
        print ("---  OK  ---")
        
def get_data(dat):
    """Copy the first two fields (RA, Dec) of each FITS record into an array.

    Returns an (N, 5) float array whose columns 2-4 stay zero; the
    commented-out lines document which proper-motion/parallax fields they
    once held.
    """
    n_rows = dat.size
    result = np.zeros((n_rows, 5))
    for row in range(n_rows):
        result[row, 0] = dat[row][0]
        result[row, 1] = dat[row][1]
        # result[row, 2] = dat[row][5]   # plx
        # result[row, 2] = dat[row][10]  # pmRA
        # result[row, 3] = dat[row][12]  # pmDE
    return result
'''        
data=fits.open('C:\\Users\\xiang_yaobing\\Desktop\\天文\\cluster_data\\正负样本\\cluster\\berkeley19.r0.1.fits')#查看单个数据
dat = data[1].data
#data_arr=np.zeros([dat.size, 5])
data_arr = get_data(dat)
data_deal = DataFrame(data_arr)
#data_deal.dropna(axis = 0, how = 'any')  
df4 = data_deal.dropna() # to dropout the line include nan
data_deal = np.array(df4)

#np.savetxt('../data_center/q.txt', data_deal)
'''

def data_wash(data_deal, fa=0.0):
    """Trim stars lying within *fa* of the field edge (simple noise cut).

    Parameters
    ----------
    data_deal : ndarray, shape (N, >=2)
        Column 0 is RA, column 1 is Dec.
    fa : float, optional
        Margin stripped from every edge of the field. Defaults to 0.0 so
        legacy callers that still pass a single argument no longer raise
        TypeError.

    Returns
    -------
    data_deal : ndarray
        Rows strictly inside the shrunken bounds.
    loc_value : list
        [raj_min, raj_max, dej_min, dej_max] of the ORIGINAL data extent.
    loc_value_s : list
        The same bounds moved inward by fa (the display range of the crop).
    """
    raj_max = np.max(data_deal[:, 0])
    raj_min = np.min(data_deal[:, 0])
    dej_max = np.max(data_deal[:, 1])
    dej_min = np.min(data_deal[:, 1])
    loc_value = [raj_min, raj_max, dej_min, dej_max]
    # Strict comparisons: points exactly on the trimmed border are dropped
    # (so even fa=0.0 removes the extreme min/max points).
    data_deal = data_deal[data_deal[:, 0] < raj_max - fa]
    data_deal = data_deal[data_deal[:, 0] > raj_min + fa]
    data_deal = data_deal[data_deal[:, 1] < dej_max - fa]
    data_deal = data_deal[data_deal[:, 1] > dej_min + fa]
    loc_value_s = [raj_min + fa, raj_max - fa, dej_min + fa, dej_max - fa]
    return data_deal, loc_value, loc_value_s
#data_deal_wish, pm_value, loc_value = data_wash(data_deal)
#dg.cluster_subplot(data_deal_wish,loc_value, pm_value)

def save_the_data_subpicture(file_path, save_path, fa=0.0):
    """Plot every FITS table found in *file_path* and save one .jpg per file.

    Parameters
    ----------
    file_path : str
        Directory containing the input .fits files.
    save_path : str
        Prefix (directory) for the output .jpg files.
    fa : float, optional
        Edge margin forwarded to data_wash(). Defaults to 0.0; the original
        code called data_wash() with a single argument, which raises
        TypeError against its two-argument signature.
    """
    filelist_open = [os.path.join(file_path, f) for f in os.listdir(file_path)]
    for fit in filelist_open:
        print(fit)
        data = fits.open(fit)
        try:
            data_arr = get_data(data[1].data)
        finally:
            data.close()  # don't leak FITS file handles across the loop
        df = DataFrame(data_arr).dropna()  # drop rows containing NaN
        cleaned, loc_value, loc_value_s = data_wash(np.array(df), fa)
        # NOTE(review): data_wash() returns (data, loc_value, loc_value_s);
        # the original unpacked these as (wish, pm_value, loc_value) and so
        # effectively passed (data, loc_value_s, loc_value) here. That
        # positional order is preserved — confirm against dg.cluster_subplot.
        dg.cluster_subplot(cleaned, loc_value_s, loc_value)
        # os.path.basename works on any platform; split('\\') only on Windows.
        plt.savefig(save_path + os.path.basename(fit) + '.jpg')
    return 0
        

def save_fits_data_realpicture(file_path, save_path, fa=0.0):
    """Generate machine-training images (loc/ and vel/) for each FITS file.

    Parameters
    ----------
    file_path : str
        Input directory. Assumed to be a Windows-style path ending in '\\'
        so that split('\\')[-2] yields the dataset folder name.
    save_path : str
        Root output directory; <save_path><dataset>\\loc\\ and \\vel\\ are
        created if missing.
    fa : float, optional
        Edge margin forwarded to data_wash(). Defaults to 0.0; the original
        one-argument call raised TypeError.
    """
    filelist_open = [os.path.join(file_path, f) for f in os.listdir(file_path)]
    file_name = (file_path.split('\\'))[-2]
    new_loc_path = save_path + file_name + '\\loc\\'
    new_vel_path = save_path + file_name + '\\vel\\'
    mkdir(new_loc_path)
    mkdir(new_vel_path)
    for fit in filelist_open:
        print(fit)
        data = fits.open(fit)
        try:
            data_arr = get_data(data[1].data)
        finally:
            data.close()  # don't leak FITS file handles across the loop
        df = DataFrame(data_arr).dropna()  # drop rows containing NaN
        cleaned, loc_value, loc_value_s = data_wash(np.array(df), fa)
        base = os.path.basename(fit)
        loc_save_path = new_loc_path + base + '.jpg'
        vel_save_path = new_vel_path + base + '.jpg'
        # NOTE(review): data_wash() no longer returns proper-motion stats;
        # the original mislabeled its returns as (wish, pm_value, loc_value).
        # Positional order preserved — the "vel" plot actually receives the
        # untrimmed location bounds. Confirm what the dg helpers expect.
        dg.cluster_loc_plotsave(loc_save_path, cleaned, loc_value_s)
        dg.cluster_vel_plotsave(vel_save_path, cleaned, loc_value)
    return 0

def deal_csv_array(file_path):
    """Read a CSV file and return its non-NaN (ra, dec) rows as an ndarray."""
    frame = pd.read_csv(file_path)
    coords = frame[['ra', 'dec']].dropna()
    return np.array(coords)

def gauss_kde_location(mix_stars, loc_value):
    """Kernel-density estimate of the positional (RA/Dec) star field.

    mix_stars: (N, >=2) array, column 0 = RA, column 1 = Dec.
    loc_value: [x_min, x_max, y_min, y_max] display bounds.

    Returns (z, xgrid): z is the flattened density sampled on a 60x60 grid
    spanning loc_value, xgrid is the meshgrid of x coordinates (its shape is
    what callers use to reshape z).
    """
    samples = np.vstack((mix_stars[:, 0], mix_stars[:, 1]))
    kde = gaussian_kde(samples)
    xs = np.linspace(loc_value[0], loc_value[1], 60)
    ys = np.linspace(loc_value[2], loc_value[3], 60)
    xgrid, ygrid = np.meshgrid(xs, ys)
    eval_points = np.vstack((xgrid.ravel(), ygrid.ravel()))
    return kde.evaluate(eval_points), xgrid

def save_csv_data_realpicture(file_path, save_path, fa=0.0):
    """Generate training images (full field and cropped) for each CSV file.

    Parameters
    ----------
    file_path : str
        Input directory. Assumed to be a Windows-style path ending in '\\'
        so that split('\\')[-2] yields the dataset folder name.
    save_path : str
        Root output directory; <save_path><dataset>\\loc\\ and \\loc0.1\\
        are created if missing.
    fa : float, optional
        Edge margin forwarded to data_wash(). Defaults to 0.0; the original
        one-argument call raised TypeError.
    """
    filelist_open = [os.path.join(file_path, f) for f in os.listdir(file_path)]
    file_name = (file_path.split('\\'))[-2]
    new_loc_path = save_path + file_name + '\\loc\\'
    new_locs_path = save_path + file_name + '\\loc0.1\\'
    mkdir(new_loc_path)
    mkdir(new_locs_path)
    for csv_file in filelist_open:
        print(csv_file)
        data_deal = deal_csv_array(csv_file)
        data_deal_wish, loc_value, loc_value_s = data_wash(data_deal, fa)
        # os.path.basename works on any platform; split('\\') only on Windows.
        base = os.path.basename(csv_file)
        loc_save_path = new_loc_path + base + '_loc.jpg'
        locs_save_path = new_locs_path + base + '_loc0.1.jpg'
        # Full field with original bounds, then the trimmed field.
        dg.cluster_loc_plotsave(loc_save_path, data_deal, loc_value)
        dg.cluster_loc_plotsave(locs_save_path, data_deal_wish, loc_value_s)
    return 0

def get_fits_to_array(file_path):
    """Read a FITS table and return its first two columns as a clean array.

    Copies field 0 (RA) and field 1 (Dec) of every record into an (N, 8)
    float array — the remaining columns stay zero; the commented-out lines
    document which proper-motion/photometry fields they once held — then
    drops any row containing NaN.
    """
    hdul = fits.open(file_path)
    records = np.array(hdul[1].data)
    table = np.zeros((records.size, 8))
    for row in range(records.size):
        table[row, 0] = records[row][0]
        table[row, 1] = records[row][1]
        # table[row, 2] = records[row][10]  # pmRA
        # table[row, 3] = records[row][12]  # pmDE
        # table[row, 5] = records[row][8]   # plx
        # table[row, 6] = records[row][17]  # Gamg
        # table[row, 7] = records[row][27]  # BP-RP
    frame = DataFrame(table).dropna()  # drop rows containing NaN
    return np.array(frame)

#file_path = 'E:\\天文\\cluster_data\\正负样本\\cluster_mini\\cluster\\'
#save_path = 'E:\\天文\\cluster_data\\正负样本\\machine_use_picture_fa1\\'
#save_the_data_subpicture(file_path, save_path)
file_path = '..\\fits\\'
save_path = '..\\trainset\\'
# save_csv_data_realpicture(file_path, save_path)
if __name__ == '__main__':
    filelist_open = [os.path.join(file_path, f) for f in os.listdir(file_path)]
    #print (file_name)
    new_loc_path = save_path + '\\fitsClusterKDEpicture_multiKde\\'
    mkdir(new_loc_path)
    for csv in filelist_open:
        print(csv)
#        data= deal_csv_array(csv)
        data = get_fits_to_array(csv)
        datas, loc_value, loc_value_s = data_wash(data, 0.05)
        datass, loc_value, loc_value_ss = data_wash(data, 0.1)
        csv_name = csv.split('\\')
        loc_save_path = new_loc_path + csv_name[-1] + '_loc.jpg'
#        locs_save_path = new_vel_path + csv_name[-1] + '_loc0.1.tif'
        
        dataList = np.zeros((60,60,3))
        z, grid = gauss_kde_location(data, loc_value)
        z = z.reshape(grid.shape)
        z = ((z-z.min())/(z.max()-z.min()))*255
        z=255-z.astype(np.int)
        #z = np.reshape(z,(grid.shape[0],grid.shape[1],1))
        dataList[:,:,0] = z
        zs, grids = gauss_kde_location(datas, loc_value_s)
        zs = zs.reshape(grid.shape)
        zs = ((zs-zs.min())/(zs.max()-zs.min()))*255
        zs=255-zs.astype(np.int)
        #zs = np.reshape(zs,(grid.shape[0],grid.shape[1],1))
        dataList[:,:,1] = zs
        zss, grids = gauss_kde_location(datass, loc_value_ss)
        zss = zss.reshape(grid.shape)
        zss = ((zss-zss.min())/(zss.max()-zss.min()))*255
        zss=255-zss.astype(np.int)
        #zss = np.reshape(zss,(grid.shape[0],grid.shape[1],1))
        #multiKde = np.reshape(np.array([[z],[zs],[zss]]),(grid.shape[0],grid.shape[1],3))
        dataList[:,:,2] = zss
        multiKde = dataList
        cv2.imwrite(loc_save_path, multiKde)
        #cv2.imwrite(locs_save_path, zs)



